{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "# Quantization aware training in Keras example\n",
    "\n",
    "## Overview\n",
    "\n",
    "Welcome to an end-to-end example for *quantization aware training*.\n",
    "\n",
    "## Learning Objectives\n",
    "\n",
    "1. Train a tf.keras model for MNIST from scratch.\n",
    "2. Fine-tune the model by applying the quantization aware training API, check the accuracy, and export a quantization aware model.\n",
    "3. Use the model to create an actually quantized model for the TFLite backend.\n",
    "4. See the persistence of accuracy in TFLite and a 4x smaller model. To see the latency benefits on mobile, try out the TFLite examples in the TFLite app repository.\n",
    "\n",
    "## Introduction\n",
    "Quantization aware training emulates inference-time quantization during training, creating a model that downstream tools use to produce actually quantized models. The quantized models use lower precision (e.g. 8-bit integers instead of 32-bit floats), which reduces model size and can reduce latency at deployment.\n",
    "\n",
    "Each learning objective corresponds to a _#TODO_ in this student lab notebook. Try to complete this notebook first and then review the [solution notebook](../solutions/training_example.ipynb)."
   ]
  },
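  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of what emulating inference-time quantization means, assume the common 8-bit affine scheme with a scale $S$ and a zero point $Z$ per tensor. These symbols are illustrative and are not names used elsewhere in this notebook. During training, a float value $r$ is rounded onto the integer grid and immediately mapped back to float:\n",
    "\n",
    "$$q = \\mathrm{clamp}\\left(\\mathrm{round}\\left(\\frac{r}{S}\\right) + Z,\\ q_{\\min},\\ q_{\\max}\\right), \\qquad \\hat{r} = S\\,(q - Z)$$\n",
    "\n",
    "For signed 8-bit values, $q_{\\min} = -128$ and $q_{\\max} = 127$. The forward pass uses $\\hat{r}$ in place of $r$, so the model can learn weights that tolerate the rounding error it will meet after conversion, while the stored weights stay in float32."
   ]
  },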
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yEAZYXvZU_XG"
   },
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-11-14T02:23:08.155226Z",
     "iopub.status.busy": "2020-11-14T02:23:08.154526Z",
     "iopub.status.idle": "2020-11-14T02:23:15.372000Z",
     "shell.execute_reply": "2020-11-14T02:23:15.372549Z"
    },
    "id": "OGNpmn43C0O6"
   },
   "outputs": [],
   "source": [
    "! pip uninstall -y tensorflow\n",
    "! pip install -q tensorflow-model-optimization\n",
    "! pip install --upgrade tensorflow==2.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-12T11:13:23.189158Z",
     "iopub.status.busy": "2020-09-12T11:13:23.188443Z",
     "iopub.status.idle": "2020-09-12T11:13:30.077432Z",
     "shell.execute_reply": "2020-09-12T11:13:30.077997Z"
    },
    "id": "yJwIonXEVJo6"
   },
   "outputs": [],
   "source": [
    "import tempfile\n",
    "import os\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow import keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook uses TensorFlow 2.x.\n",
    "Check your TensorFlow version using the cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorFlow version:  2.6.0\n"
     ]
    }
   ],
   "source": [
    "# Show the currently installed version of TensorFlow\n",
    "print(\"TensorFlow version: \", tf.version.VERSION)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "psViY5PRDurp"
   },
   "source": [
    "## Train a model for MNIST without quantization aware training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-12T11:13:30.085609Z",
     "iopub.status.busy": "2020-09-12T11:13:30.084862Z",
     "iopub.status.idle": "2020-09-12T11:13:38.723898Z",
     "shell.execute_reply": "2020-09-12T11:13:38.724344Z"
    },
    "id": "pbY-KGMPvbW9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
      "11493376/11490434 [==============================] - 0s 0us/step\n",
      "11501568/11490434 [==============================] - 0s 0us/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1688/1688 [==============================] - 12s 7ms/step - loss: 0.3103 - accuracy: 0.9129 - val_loss: 0.1387 - val_accuracy: 0.9632\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f5860827f50>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load MNIST dataset\n",
    "mnist = keras.datasets.mnist\n",
    "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
    "\n",
    "# Normalize the input images so that each pixel value is between 0 and 1.\n",
    "train_images = train_images / 255.0\n",
    "test_images = test_images / 255.0\n",
    "\n",
    "# Define the model architecture.\n",
    "\n",
    "# TODO: Your code goes here\n",
    "\n",
    "\n",
    "# Train the digit classification model\n",
    "\n",
    "# TODO: Your code goes here\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "K8747K9OE72P"
   },
   "source": [
    "## Clone and fine-tune pre-trained model with quantization aware training\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "F19k7ExXF_h2"
   },
   "source": [
    "### Define the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JsZROpNYMWQ0"
   },
   "source": [
    "You will apply quantization aware training to the whole model and see this in the model summary. All layers are now prefixed by \"quant\".\n",
    "\n",
    "Note that the resulting model is quantization aware but not quantized (e.g. the weights are float32 instead of int8). The following sections show how to create a quantized model from the quantization aware one.\n",
    "\n",
    "In the [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.md), you can see how to quantize only some layers to improve model accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-12T11:13:38.729315Z",
     "iopub.status.busy": "2020-09-12T11:13:38.728674Z",
     "iopub.status.idle": "2020-09-12T11:13:40.191848Z",
     "shell.execute_reply": "2020-09-12T11:13:40.192270Z"
    },
    "id": "oq6blGjgFDCW"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "quantize_layer (QuantizeLaye (None, 28, 28)            3         \n",
      "_________________________________________________________________\n",
      "quant_reshape (QuantizeWrapp (None, 28, 28, 1)         1         \n",
      "_________________________________________________________________\n",
      "quant_conv2d (QuantizeWrappe (None, 26, 26, 12)        147       \n",
      "_________________________________________________________________\n",
      "quant_max_pooling2d (Quantiz (None, 13, 13, 12)        1         \n",
      "_________________________________________________________________\n",
      "quant_flatten (QuantizeWrapp (None, 2028)              1         \n",
      "_________________________________________________________________\n",
      "quant_dense (QuantizeWrapper (None, 10)                20295     \n",
      "=================================================================\n",
      "Total params: 20,448\n",
      "Trainable params: 20,410\n",
      "Non-trainable params: 38\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "import tensorflow_model_optimization as tfmot\n",
    "\n",
    "quantize_model = tfmot.quantization.keras.quantize_model\n",
    "\n",
    "# q_aware stands for quantization aware.\n",
    "q_aware_model = quantize_model(model)\n",
    "\n",
    "# `quantize_model` requires a recompile.\n",
    "q_aware_model.compile(optimizer='adam',\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "q_aware_model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uDr2ijwpGCI-"
   },
   "source": [
    "### Train and evaluate the model against baseline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XUBEn94hXYB1"
   },
   "source": [
    "To demonstrate fine-tuning after training the model for just one epoch, fine-tune with quantization aware training on a subset of the training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-12T11:13:40.197720Z",
     "iopub.status.busy": "2020-09-12T11:13:40.197034Z",
     "iopub.status.idle": "2020-09-12T11:13:40.938811Z",
     "shell.execute_reply": "2020-09-12T11:13:40.939194Z"
    },
    "id": "_PHDGJryE31X"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2/2 [==============================] - 1s 167ms/step - loss: 0.1298 - accuracy: 0.9571 - val_loss: 0.1671 - val_accuracy: 0.9600\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f376c6db7b8>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_images_subset = train_images[0:1000] # out of 60000\n",
    "train_labels_subset = train_labels[0:1000]\n",
    "\n",
    "q_aware_model.fit(train_images_subset, train_labels_subset,\n",
    "                  batch_size=500, epochs=1, validation_split=0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-byC2lYlMkfN"
   },
   "source": [
    "For this example, there is minimal to no loss in test accuracy after quantization aware training, compared to the baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-12T11:13:40.944142Z",
     "iopub.status.busy": "2020-09-12T11:13:40.943546Z",
     "iopub.status.idle": "2020-09-12T11:13:42.104104Z",
     "shell.execute_reply": "2020-09-12T11:13:42.104515Z"
    },
    "id": "6bMFTKSSHyyZ"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Baseline test accuracy: 0.9609000086784363\n",
      "Quant test accuracy: 0.9628999829292297\n"
     ]
    }
   ],
   "source": [
    "_, baseline_model_accuracy = model.evaluate(\n",
    "    test_images, test_labels, verbose=0)\n",
    "\n",
    "_, q_aware_model_accuracy = q_aware_model.evaluate(\n",
    "    test_images, test_labels, verbose=0)\n",
    "\n",
    "print('Baseline test accuracy:', baseline_model_accuracy)\n",
    "print('Quant test accuracy:', q_aware_model_accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2IepmUPSITn6"
   },
   "source": [
    "## Create quantized model for TFLite backend"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1FgNP4rbOLH8"
   },
   "source": [
    "After running the conversion below, you have an actually quantized model with int8 weights and uint8 activations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-12T11:13:42.111077Z",
     "iopub.status.busy": "2020-09-12T11:13:42.110438Z",
     "iopub.status.idle": "2020-09-12T11:13:43.728444Z",
     "shell.execute_reply": "2020-09-12T11:13:43.727865Z"
    },
    "id": "w7fztWsAOHTz"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/tmpoct8ii0p/assets\n"
     ]
    }
   ],
   "source": [
    "converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)\n",
    "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
    "\n",
    "quantized_tflite_model = converter.convert()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BEYsyYVqNgeY"
   },
   "source": [
    "## See persistence of accuracy from TF to TFLite"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "saadXD4JQsBK"
   },
   "source": [
    "Define a helper function to evaluate the TFLite model on the test dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-12T11:13:43.737938Z",
     "iopub.status.busy": "2020-09-12T11:13:43.737258Z",
     "iopub.status.idle": "2020-09-12T11:13:43.738944Z",
     "shell.execute_reply": "2020-09-12T11:13:43.739320Z"
    },
    "id": "b8yBouuGNqls"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def evaluate_model(interpreter):\n",
    "  input_index = interpreter.get_input_details()[0][\"index\"]\n",
    "  output_index = interpreter.get_output_details()[0][\"index\"]\n",
    "\n",
    "  # Run predictions on every image in the \"test\" dataset.\n",
    "  prediction_digits = []\n",
    "  for i, test_image in enumerate(test_images):\n",
    "    if i % 1000 == 0:\n",
    "      print('Evaluated on {n} results so far.'.format(n=i))\n",
    "    # Pre-processing: add batch dimension and convert to float32 to match with\n",
    "    # the model's input data format.\n",
    "\n",
    "    # TODO: Your code goes here\n",
    "\n",
    "\n",
    "    # Run inference.\n",
    "    interpreter.invoke()\n",
    "\n",
    "    # Post-processing: remove batch dimension and find the digit with highest\n",
    "    # probability.\n",
    "    output = interpreter.tensor(output_index)\n",
    "    digit = np.argmax(output()[0])\n",
    "    prediction_digits.append(digit)\n",
    "\n",
    "  print('\\n')\n",
    "  # Compare prediction results with ground truth labels to calculate accuracy.\n",
    "  prediction_digits = np.array(prediction_digits)\n",
    "  accuracy = (prediction_digits == test_labels).mean()\n",
    "  return accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TuEFS4CIQvUw"
   },
   "source": [
    "You evaluate the quantized model and see that the accuracy from TensorFlow persists to the TFLite backend."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-12T11:13:43.743998Z",
     "iopub.status.busy": "2020-09-12T11:13:43.743384Z",
     "iopub.status.idle": "2020-09-12T11:13:47.596242Z",
     "shell.execute_reply": "2020-09-12T11:13:47.595653Z"
    },
    "id": "VqQTyqz4NsWd"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluated on 0 results so far.\n",
      "Evaluated on 1000 results so far.\n",
      "Evaluated on 2000 results so far.\n",
      "Evaluated on 3000 results so far.\n",
      "Evaluated on 4000 results so far.\n",
      "Evaluated on 5000 results so far.\n",
      "Evaluated on 6000 results so far.\n",
      "Evaluated on 7000 results so far.\n",
      "Evaluated on 8000 results so far.\n",
      "Evaluated on 9000 results so far.\n",
      "\n",
      "\n",
      "Quant TFLite test_accuracy: 0.963\n",
      "Quant TF test accuracy: 0.9628999829292297\n"
     ]
    }
   ],
   "source": [
    "interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)\n",
    "interpreter.allocate_tensors()\n",
    "\n",
    "test_accuracy = evaluate_model(interpreter)\n",
    "\n",
    "print('Quant TFLite test_accuracy:', test_accuracy)\n",
    "print('Quant TF test accuracy:', q_aware_model_accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "z8D7WnFF5DZR"
   },
   "source": [
    "## See 4x smaller model from quantization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "I1c2IecBRCdQ"
   },
   "source": [
    "Create a float TFLite model, then compare file sizes to see that the quantized TFLite model\n",
    "is roughly 4x smaller."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-12T11:13:47.603575Z",
     "iopub.status.busy": "2020-09-12T11:13:47.602885Z",
     "iopub.status.idle": "2020-09-12T11:13:48.243685Z",
     "shell.execute_reply": "2020-09-12T11:13:48.243157Z"
    },
    "id": "jy_Lgfh8VkyX"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/tmpy0445k6u/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Float model in MB: 0.08053970336914062\n",
      "Quantized model in MB: 0.02339935302734375\n"
     ]
    }
   ],
   "source": [
    "# Create float TFLite model.\n",
    "# TODO: Your code goes here\n",
    "\n",
    "\n",
    "# Measure sizes of models.\n",
    "_, float_file = tempfile.mkstemp('.tflite')\n",
    "_, quant_file = tempfile.mkstemp('.tflite')\n",
    "\n",
    "with open(quant_file, 'wb') as f:\n",
    "  f.write(quantized_tflite_model)\n",
    "\n",
    "with open(float_file, 'wb') as f:\n",
    "  f.write(float_tflite_model)\n",
    "\n",
    "print(\"Float model in MB:\", os.path.getsize(float_file) / float(2**20))\n",
    "print(\"Quantized model in MB:\", os.path.getsize(quant_file) / float(2**20))"
   ]
  },
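  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the size claim against the output recorded above (the ratio $\\rho$ is a symbol introduced here for illustration, and your numbers may differ from run to run):\n",
    "\n",
    "$$\\rho = \\frac{\\text{float size}}{\\text{quantized size}} = \\frac{0.0805}{0.0234} \\approx 3.4$$\n",
    "\n",
    "Weights shrink by 4x when going from 32-bit floats to 8-bit integers, but the file also holds data that does not shrink the same way, such as the graph structure and quantization parameters, so the whole-file ratio for a model this small lands a little below 4x."
   ]
  },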
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0O5xuci-SonI"
   },
   "source": [
    "## Conclusion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "O2I7xmyMW5QY"
   },
   "source": [
    "In this tutorial, you saw how to create quantization aware models with the TensorFlow Model Optimization Toolkit API and then quantized models for the TFLite backend.\n",
    "\n",
    "You saw a roughly 4x reduction in model size for an MNIST model, with minimal accuracy\n",
    "difference. To see the latency benefits on mobile, try out the TFLite examples [in the TFLite app repository](https://www.tensorflow.org/lite/models).\n",
    "\n",
    "We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments."
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "training_example.ipynb",
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
